from core.c_dataset import CassavaDataset


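# Per-class image counts over the full labeled set (5 classes, 21,397 images):
# num_per_disease: 1087, 2189, 2386, 13158, 2577

# After splitting into two subsets (their counts sum element-wise to the above):
# num_per_disease:  [975, 1979, 2158, 11934, 2351]   (19,397 images)
# num_per_disease:  [112, 210, 228, 1224, 226]       (2,000 images)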

def count_per_disease(dataset, num_classes=5):
    """Count how many items of each class label a dataset contains.

    Args:
        dataset: Indexable object supporting ``len()`` whose items unpack
            into ``(img, label, fileName)`` triples, as ``CassavaDataset``
            items do in the script below. ``label`` must be convertible
            with ``int()``.
        num_classes: Number of distinct class labels; every label must lie
            in ``[0, num_classes)``.

    Returns:
        A list of length ``num_classes`` whose entry ``k`` is the number of
        items labeled ``k``.

    Raises:
        ValueError: If an item's label falls outside ``[0, num_classes)``.
    """
    counts = [0] * num_classes
    # Index explicitly instead of iterating: a map-style dataset may only
    # define __len__/__getitem__ without signalling the end via IndexError.
    # NOTE(review): each lookup also returns the image, presumably loading
    # it from disk; counting from the label table directly would be faster
    # if the dataset exposes it -- confirm.
    for idx in range(len(dataset)):
        _img, label, file_name = dataset[idx]
        cls = int(label)
        # Reject out-of-range labels explicitly: a negative label would
        # otherwise wrap around and silently inflate the last class.
        if not 0 <= cls < num_classes:
            raise ValueError(
                f"label {cls} of {file_name!r} is outside [0, {num_classes})"
            )
        counts[cls] += 1
    return counts


if __name__ == '__main__':
    root = "/home/handewei/data/cassava-leaf/"

    trainSet = CassavaDataset(root, "train_nohead.csv")
    num_per_disease = count_per_disease(trainSet)

    print("num_per_disease: ", num_per_disease)